# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Decoder registers and factory method.

A new decoder model can be registered in the following two steps:

1. Import the factory and register the decoder builder in the decoder file.
2. Import the decoder class in decoders/__init__.py so the registration runs.

```
# my_decoder.py

from modeling.decoders import factory

class MyDecoder():
  ...

@factory.register_decoder_builder('my_decoder')
def build_my_decoder(input_specs, model_config, l2_regularizer=None,
                     **kwargs):
  return MyDecoder()

# decoders/__init__.py adds import
from modeling.decoders.my_decoder import MyDecoder
```

If the MyDecoder class should be used only by specific binaries, do not import
the decoder module in decoders/__init__.py; instead, import it in the place
that uses it.
"""
from typing import Any, Callable, Mapping, Optional, Union

import tensorflow as tf, tf_keras

from official.core import registry
from official.modeling import hyperparams

_REGISTERED_DECODER_CLS = {}


def register_decoder_builder(key: str) -> Callable[..., Any]:
  """Decorates a builder of decoder class.

  The builder should be a Callable (a class or a function). `build_decoder`
  invokes it with the keyword arguments `input_specs`, `model_config` and
  `l2_regularizer`, plus any extra keyword arguments passed to
  `build_decoder`. The builder's parameters must therefore use those names.

  This decorator supports registration of a decoder builder as follows:

  ```
  class MyDecoder(tf_keras.Model):
    pass

  @register_decoder_builder('mydecoder')
  def builder(input_specs, model_config, l2_regularizer=None, **kwargs):
    return MyDecoder(...)

  # Builds a MyDecoder object when `model_config.decoder.type` is
  # 'mydecoder'.
  my_decoder = build_decoder(input_specs, model_config, l2_regularizer)
  ```

  Args:
    key: A `str` of key to look up the builder.

  Returns:
    A callable to use as a decorator. It registers the decorated builder in
    the module-level decoder registry under `key`.
  """
  return registry.register(_REGISTERED_DECODER_CLS, key)


@register_decoder_builder('identity')
def build_identity(
    input_specs: Optional[Mapping[str, tf.TensorShape]] = None,
    model_config: Optional[hyperparams.Config] = None,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None) -> None:
  """Builds identity decoder from a config.

  The identity decoder is represented by `None`. No decoder layer is created,
  so the features produced by the backbone serve as the model output.

  None of the input arguments are used. They are kept so that this builder
  accepts the same keyword arguments that `build_decoder` passes to every
  registered builder.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A `OneOfConfig` of model config.
    l2_regularizer: A `tf_keras.regularizers.Regularizer` object. Default to
      None.

  Returns:
    Always `None`, meaning the identity decoder has no layer.
  """
  del input_specs, model_config, l2_regularizer  # Unused by identity decoder.
  # Make the "no decoder" result explicit rather than relying on fall-through.
  return None


def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    **kwargs) -> Union[None, tf_keras.Model, tf_keras.layers.Layer]:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Builds decoder from a config.

  A decoder can be a keras.Model, a keras.layers.Layer, or None. If it is not
  None, the decoder takes features from the backbone as input and generates
  decoded feature maps. If it is None (for example, the 'identity' decoder),
  the decoder is skipped and features from the backbone are regarded as the
  model output.

  The builder is selected by `model_config.decoder.type`. It is called with
  `input_specs`, `model_config` and `l2_regularizer` as keyword arguments,
  plus any `**kwargs`.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A `OneOfConfig` of model config.
    l2_regularizer: A `tf_keras.regularizers.Regularizer` object. Default to
      None.
    **kwargs: Additional keyword args to be passed to decoder builder. They
      must not repeat `input_specs`, `model_config` or `l2_regularizer`;
      Python raises `TypeError` for a duplicated keyword argument.

  Returns:
    The object returned by the registered builder: a decoder instance, or
    `None` when no decoder should be applied.
  """
  # NOTE(review): if `model_config.decoder.type` is not registered, the error
  # comes from `registry.lookup`; confirm the exception type there.
  decoder_builder = registry.lookup(_REGISTERED_DECODER_CLS,
                                    model_config.decoder.type)

  return decoder_builder(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer,
      **kwargs)
